import cv2
import os
import numpy as np
from Load_util import get_drivers
import time
from skimage import io
from sklearn.cross_validation import KFold


def create_train(train_dir='/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/train',
                 save_path='/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/npy/train',
                 num_classes=10):
    """Stack every training image of each class into one ``<label>.npy`` file.

    For each class index ``k`` in ``range(num_classes)`` the images in
    ``<train_dir>/c<k>`` are read with ``cv2.imread`` (BGR channel order),
    stacked into a float32 array and saved as ``<save_path>/<k>.npy``.

    Previously the label counter was never incremented, so every class was
    written to ``0.npy`` and each overwrote the previous one.

    Args:
        train_dir: directory holding the ``c0`` .. ``c<num_classes-1>`` folders.
        save_path: output directory for the ``.npy`` files (created if missing).
        num_classes: number of class folders to process.

    Raises:
        IOError: if an image file cannot be decoded by OpenCV.
    """
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    for label in range(num_classes):
        class_dir = os.path.join(train_dir, 'c%d' % label)
        print(class_dir)
        img_arrs = []
        # Sorted so the sample order inside each .npy is reproducible.
        for img in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, img)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise IOError('could not read image: %s' % img_path)
            img_arrs.append(img_arr)
        # np.asarray needs all images of a class to share one shape to
        # produce a 4-D array — presumably true for this dataset; verify.
        img_arrs = np.asarray(img_arrs, dtype=np.float32)
        np.save(os.path.join(save_path, str(label) + '.npy'), img_arrs)


def create_npy(train_path='/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'):
    """Build 5-crop augmented training arrays from the class folders under *train_path*.

    Each image is read with skimage, converted from RGB to BGR channel order,
    cropped to columns 80:560, resized to 256x256 and then cut into five
    224x224 crops (the four corners and the centre). Every crop is tagged with
    the class label parsed from its folder name and the driver looked up for
    its filename.

    The previous version computed all of this and then discarded it; the
    results are now returned (callers that ignored the old ``None`` return
    are unaffected).

    Args:
        train_path: directory with one sub-folder per class, named like
            ``c0`` .. ``c9``. Non-directory entries are skipped.

    Returns:
        tuple ``(image_narrays, lable_narrays, drivers_id, unique_drivers)``:
        float32 images transposed to ``(N, 3, 224, 224)``, int32 labels of
        length N, the per-crop driver list, and its sorted unique values.

    Raises:
        ValueError: if no images were found under *train_path*.
    """
    # NOTE(review): get_drivers() is assumed to map image filename -> driver
    # id (it is indexed by the bare filename below) — confirm in Load_util.
    drivers_all = get_drivers()
    crop = 224
    # Row/column offsets into the 256x256 image: top-left, bottom-left,
    # centre, top-right, bottom-right.
    offsets = ((0, 0), (32, 0), (16, 16), (0, 32), (32, 32))
    labels = []
    img_arrs = []
    drivers_id = []  # driver of each crop, aligned with img_arrs/labels
    for c in sorted(os.listdir(train_path)):
        train_path_c = os.path.join(train_path, c)
        if not os.path.isdir(train_path_c):
            continue
        begin = time.time()
        n_crops = 0
        # Folder names look like 'c0'..'c9'; lstrip also handles 'c10'+.
        label = int(c.lstrip('c'))
        for img in sorted(os.listdir(train_path_c)):
            img_arr = io.imread(os.path.join(train_path_c, img))
            # skimage decodes as RGB; convert to OpenCV's BGR convention.
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
            img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, (256, 256), interpolation=cv2.INTER_AREA)
            driver = drivers_all[img]
            for dy, dx in offsets:
                img_arrs.append(img_arr[dy:dy + crop, dx:dx + crop])
                labels.append(label)
                drivers_id.append(driver)
                n_crops += 1
        print(n_crops)
        print('time: {0}'.format(time.time() - begin))
    if not img_arrs:
        raise ValueError('no training images found under %s' % train_path)
    unique_drivers = sorted(set(drivers_id))
    lable_narrays = np.asarray(labels, dtype=np.int32)
    image_narrays = np.asarray(img_arrs, dtype=np.float32)
    # (N, H, W, C) -> (N, C, H, W)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, lable_narrays, drivers_id, unique_drivers


def create_cropped_train(folder_path='/media/wjsun/delldisk/dell/wxm/Data/KaggleDDD/train',
                         folder_path_save='/media/wjsun/delldisk/dell/wxm/Data/KaggleDDD/train_cropped'):
    """Crop every training image to columns 80:560 and save it per class.

    Images under ``<folder_path>/<class>/`` are written to
    ``<folder_path_save>/<class>/`` with the same filename.

    Fixes compared with the previous version:
    * ``skimage.io.imread`` returns RGB while ``cv2.imwrite`` expects BGR, so
      saved images had red/blue swapped; colour images are now converted.
    * class folders are already named ``c0``..; the old ``'/c' + c`` produced
      ``cc0`` output folders. The source folder name is now reused as-is.
    * output folders (including missing parents) are created once per class.
    * write failures raise instead of being silently ignored.

    Args:
        folder_path: directory with one sub-folder per class; non-directory
            entries are skipped.
        folder_path_save: output root directory.

    Raises:
        IOError: if ``cv2.imwrite`` reports a failed write.
    """
    for c in sorted(os.listdir(folder_path)):
        train_path_c = os.path.join(folder_path, c)
        if not os.path.isdir(train_path_c):
            continue
        save_dir_c = os.path.join(folder_path_save, c)
        if not os.path.isdir(save_dir_c):
            os.makedirs(save_dir_c)
        for img in os.listdir(train_path_c):
            img_arr = io.imread(os.path.join(train_path_c, img))
            if img_arr.ndim == 3 and img_arr.shape[2] == 3:
                # skimage yields RGB; cv2.imwrite interprets channels as BGR.
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
            img_arr = img_arr[:, 80:560]
            out_path = os.path.join(save_dir_c, img)
            # cv2.imwrite returns False rather than raising on failure.
            if not cv2.imwrite(out_path, img_arr):
                raise IOError('failed to write image: %s' % out_path)


def create_cropped_test(folder_path='/media/wjsun/delldisk/dell/wxm/Data/KaggleDDD/test',
                        folder_path_save='/media/wjsun/delldisk/dell/wxm/Data/KaggleDDD/test_cropped'):
    """Crop every test image to columns 80:560 and save it under *folder_path_save*.

    Fixes compared with the previous version:
    * ``skimage.io.imread`` returns RGB while ``cv2.imwrite`` expects BGR, so
      saved images had red/blue swapped; colour images are now converted.
    * the output directory is created if missing (``cv2.imwrite`` otherwise
      returns False and nothing is written).
    * write failures raise instead of being silently ignored.

    Args:
        folder_path: directory containing the test images.
        folder_path_save: output directory; files keep their original names.

    Raises:
        IOError: if ``cv2.imwrite`` reports a failed write.
    """
    if not os.path.isdir(folder_path_save):
        os.makedirs(folder_path_save)
    for img in os.listdir(folder_path):
        img_arr = io.imread(os.path.join(folder_path, img))
        if img_arr.ndim == 3 and img_arr.shape[2] == 3:
            # skimage yields RGB; cv2.imwrite interprets channels as BGR.
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
        img_arr = img_arr[:, 80:560]
        out_path = os.path.join(folder_path_save, img)
        # cv2.imwrite returns False rather than raising on failure.
        if not cv2.imwrite(out_path, img_arr):
            raise IOError('failed to write image: %s' % out_path)


